package com.yun.jing.common.config.encryption;


import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yun.jing.common.constant.Sm2Contant;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.MediaType;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;

/**
 * Request wrapper that decrypts encrypted request parameters.
 *
 * <p>For {@code application/x-www-form-urlencoded} requests the {@code data} form field is
 * decrypted (via {@code EncryptUtils.getSm2Data}) and its JSON content is merged into the
 * parameter map together with the remaining form fields. For {@code application/json} requests
 * the {@code data} property of the JSON body is decrypted and the resulting JSON becomes both the
 * parameter map and the body replayed by {@link #getInputStream()} / {@link #getReader()}.
 * Query-string parameters are added afterwards for keys not already present.
 */
@Slf4j
public class DataEncryptionWrapper extends HttpServletRequestWrapper {

    /** Shared Jackson mapper; ObjectMapper is thread-safe for reads once configured. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Key of the single encrypted field carried by form posts and JSON bodies. */
    private static final String REQ_BODY_KEY = "data";

    /** Character encoding used for URL decoding and body (de)serialization. */
    private static final String UTF_8 = "UTF-8";

    private final ApiEncryptServer encryptService;

    /**
     * Decrypted JSON body replayed to downstream readers. Stays empty when the request is not
     * JSON or carries no non-blank {@code data} property.
     * NOTE(review): an unencrypted JSON body is therefore NOT passed through; confirm this
     * "encryption required" behaviour is intended.
     */
    private String body = "";

    /**
     * Query, form-data and decrypted-body parameters served by the parameter accessors.
     * Form/query values are {@code String[]}; values from decrypted JSON keep whatever type
     * Jackson parsed (String, Number, Map, List, ...).
     */
    private final Map<String, Object> params = new HashMap<>();

    /**
     * Wraps {@code request}, eagerly consuming and decrypting its parameters/body.
     *
     * @param request        the original request
     * @param encryptService service handed to {@code EncryptUtils.getSm2Data} for decryption
     * @throws IOException if the body cannot be read or the decrypted JSON cannot be parsed
     */
    public DataEncryptionWrapper(HttpServletRequest request, ApiEncryptServer encryptService) throws IOException {
        super(request);
        this.encryptService = encryptService;
        String contentType = request.getContentType();
        if (isContentType(contentType, MediaType.APPLICATION_FORM_URLENCODED_VALUE)) {
            loadFormParameters(request);
        } else if (isContentType(contentType, MediaType.APPLICATION_JSON_VALUE)) {
            loadJsonBody(request);
        }

        // Query parameters must still be visible (e.g. after RequestDispatcher.forward).
        renewParameterMap(request);
    }

    /** Returns true when {@code contentType} is non-blank and starts with {@code mediaType}. */
    private static boolean isContentType(String contentType, String mediaType) {
        return StringUtils.isNotBlank(contentType)
                && org.springframework.util.StringUtils.substringMatch(contentType, 0, mediaType);
    }

    /** Decrypts the {@code data} form field into {@link #params} and copies all other form fields. */
    private void loadFormParameters(HttpServletRequest request) throws IOException {
        Map<String, String[]> parametersMap = request.getParameterMap();
        String[] encrypted = parametersMap.get(REQ_BODY_KEY);
        if (encrypted != null && encrypted.length > 0) {
            String deJson = EncryptUtils.getSm2Data(encrypted[0], encryptService, Sm2Contant.privateKey);
            this.params.putAll(OBJECT_MAPPER.readValue(deJson, Map.class));
        }
        // Copy every other form field unchanged.
        for (Map.Entry<String, String[]> entry : parametersMap.entrySet()) {
            if (!REQ_BODY_KEY.equals(entry.getKey())) {
                this.params.put(entry.getKey(), entry.getValue());
            }
        }
    }

    /** Reads the JSON body, decrypts its {@code data} property and stores the result in {@link #body}/{@link #params}. */
    private void loadJsonBody(HttpServletRequest request) throws IOException {
        String rawBody = readBody(request);

        String requestURI = request.getRequestURI();
        log.info("====》 请求路径：{}，原请求query参数:{}", requestURI, request.getQueryString());

        // A body wrapped in [...] is unwrapped so its single element can be parsed as an object.
        String bodyStr = stripOuterBrackets(rawBody);
        if (StringUtils.isBlank(bodyStr)) {
            return;
        }
        JSONObject jsonObject = JSONObject.parseObject(bodyStr);
        if (jsonObject == null) {
            return;
        }
        log.info("====》 请求路径：{}，原请求body参数体：{}", requestURI,
                JSONObject.toJSONString(jsonObject, SerializerFeature.WriteMapNullValue));
        String enJson = jsonObject.getString(REQ_BODY_KEY);
        if (StringUtils.isNotBlank(enJson)) {
            String deJson = EncryptUtils.getSm2Data(enJson, encryptService, Sm2Contant.privateKey);
            body = stripOuterBrackets(deJson);
            this.params.putAll(OBJECT_MAPPER.readValue(body, Map.class));
        }
    }

    /** Reads the whole request body as UTF-8; returns "" when the container supplies no stream. */
    private static String readBody(HttpServletRequest request) throws IOException {
        InputStream inputStream = request.getInputStream();
        if (inputStream == null) {
            return "";
        }
        StringBuilder builder = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            char[] buffer = new char[1024];
            int read;
            while ((read = reader.read(buffer)) > 0) {
                builder.append(buffer, 0, read);
            }
        }
        return builder.toString();
    }

    /** Removes one pair of surrounding square brackets, if both are present. */
    private static String stripOuterBrackets(String text) {
        return text.startsWith("[") && text.endsWith("]") ? text.substring(1, text.length() - 1) : text;
    }

    /**
     * URL-decodes {@code value} as UTF-8. A value that is not valid percent-encoding (for example a
     * literal "%" left after the container's own decoding) is returned unchanged instead of failing
     * the whole request.
     */
    private static String urlDecode(String value) {
        if (value == null) {
            return null;
        }
        try {
            return URLDecoder.decode(value, UTF_8);
        } catch (UnsupportedEncodingException | IllegalArgumentException e) {
            log.warn("URL decode failed, using raw value: {}", e.getMessage());
            return value;
        }
    }

    /** Best-effort decryption: returns the decrypted value, or the input unchanged if it is blank or decryption fails. */
    private String decryptQuietly(String value) {
        if (StringUtils.isBlank(value)) {
            return value;
        }
        try {
            return EncryptUtils.getSm2Data(value, encryptService, Sm2Contant.privateKey);
        } catch (Exception e) {
            log.warn("解密param请求参数失败: {}", e.getMessage());
            return value;
        }
    }

    /** Looks up {@code name} in {@link #params}, falling back to the decrypted JSON body. */
    private Object lookup(String name) {
        Object value = params.get(name);
        if (value == null && StringUtils.isNotBlank(body)) {
            JSONObject parsed = JSONObject.parseObject(body);
            if (parsed != null) {
                value = parsed.get(name);
            }
        }
        return value;
    }

    /**
     * Returns the first value of {@code name}, URL-decoded when it is a string; non-string JSON
     * values are returned via {@code toString()}. Unlike {@link #getParameterValues}, no decryption
     * is attempted.
     * NOTE(review): container/query values are already decoded, so this decodes them a second time
     * ("+" becomes a space); kept for compatibility with existing clients.
     */
    @Override
    public String getParameter(String name) {
        Object value = lookup(name);
        if (value == null) {
            return null;
        }
        if (value instanceof String[]) {
            String[] values = (String[]) value;
            return values.length > 0 ? urlDecode(values[0]) : null;
        }
        if (value instanceof String) {
            return urlDecode((String) value);
        }
        return value.toString();
    }

    /**
     * Returns the backing parameter map itself (not a copy). Values are not guaranteed to be
     * {@code String[]}: entries from decrypted JSON keep their parsed type.
     */
    @Override
    public Map getParameterMap() {
        return params;
    }

    /** Returns a snapshot of the parameter names. */
    @Override
    public Enumeration<String> getParameterNames() {
        return new Vector<>(params.keySet()).elements();
    }

    /**
     * Returns all values of {@code name}, each URL-decoded and then decrypted on a best-effort basis
     * (values that fail to decrypt are returned decoded only). The stored array is never modified,
     * so repeated calls return the same result.
     */
    @Override
    public String[] getParameterValues(String name) {
        Object value = lookup(name);
        if (value == null) {
            return null;
        }
        if (value instanceof String[]) {
            String[] source = (String[]) value;
            String[] result = new String[source.length];
            for (int i = 0; i < source.length; i++) {
                result[i] = decryptQuietly(urlDecode(source[i]));
            }
            return result;
        }
        // Single String or any other JSON value: treat its string form as one value.
        return new String[]{decryptQuietly(urlDecode(value.toString()))};
    }

    /** Replays the decrypted body (UTF-8); empty when no encrypted JSON body was found. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8));
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The body is fully buffered in memory, so reads never block.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reading is not supported; the listener is ignored.
            }

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public int available() {
                return source.available();
            }
        };
    }

    /** Reader over {@link #getInputStream()} using UTF-8. */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), StandardCharsets.UTF_8));
    }

    /**
     * Adds query-string parameters whose (decoded) key is not already present. Pairs without "="
     * are skipped; only the first occurrence of a repeated key is kept.
     */
    private void renewParameterMap(HttpServletRequest req) {
        String queryString = req.getQueryString();
        if (queryString == null || queryString.trim().isEmpty()) {
            return;
        }
        for (String pair : queryString.split("&")) {
            int splitIndex = pair.indexOf('=');
            if (splitIndex == -1) {
                continue;
            }
            // Decode the key as well so it matches container-decoded keys.
            String key = urlDecode(pair.substring(0, splitIndex));
            if (!this.params.containsKey(key)) {
                this.params.put(key, new String[]{urlDecode(pair.substring(splitIndex + 1))});
            }
        }
    }
}

